from z3 import *
import pickle
import numpy as np
from numba import jit, cuda 

# Per-poem relation tensors keyed by poem, produced by an earlier pipeline step.
# NOTE: pickle can execute arbitrary code on load -- only load trusted files.
with open("pickles/poem_vecs.pcl", "rb") as _poem_vecs_file:
    poem_vecs = pickle.load(_poem_vecs_file)

# Number of relation channels; presumably equals the last axis of each vec -- confirm.
num_relations = 7

# Results keyed by the tuple of a selected section's lines:
# total relation score, and the section's adjacency tensor.
poem_sec_scores={}
poem_sec_vecs={}


set_param(timeout=30*1000) #comment out if you want more accuracy but potentially longer runtime

# Total relation weight inside the `size`-line section (default 10) starting at line i.
def getSum(vec, i, size=10):
	"""Return the summed relation weight of the section of lines i .. i + size - 1.

	`vec` is indexed as vec[line, offset, relation] (rank 3, per the slice
	below). Offset d presumably links a line to the line d places after it
	(the same mapping getAdj uses). For line i + a of the section only offsets
	d < size - a are counted, so relations reaching past the section's last
	line are excluded.

	Args:
		vec: 3-D array of per-line relation weights.
		i: index of the section's first line.
		size: number of lines per section (the original code used 10).

	Returns:
		A NumPy scalar total. Raises IndexError if i + size > len(vec).
	"""
	return np.sum([np.sum(vec[i + a, :size - a, :]) for a in range(size)])

def getAdj(vec, sec, size=10):
	"""Build the symmetric (size, size, R) relation tensor for one section.

	Entry vec[sec + i, j, :] (line sec + i, offset j) is placed at both
	[i, i + j] and [i + j, i] of the result, for every pair of lines that
	lies inside the section. R is taken from vec's last axis instead of
	being hard-coded, so tensors with any number of relation channels work.

	Args:
		vec: 3-D array indexed as vec[line, offset, relation].
		sec: index of the section's first line.
		size: number of lines per section (the original code used 10).

	Returns:
		A float64 array of shape (size, size, vec.shape[2]).
	"""
	num_rel = vec.shape[2]
	new_vec = np.zeros((size, size, num_rel))
	for i in range(size):
		for j in range(size - i):
			# Mirror each pair so the matrix is symmetric in its first two axes.
			new_vec[i, i + j, :] = vec[sec + i, j, :]
			new_vec[i + j, i, :] = vec[sec + i, j, :]
	return new_vec

SECTION_LEN = 10  # lines per section; matches the default window of getSum/getAdj


def _select_sections(vec):
    """Return the start lines of the non-overlapping sections chosen by Z3.

    Every valid start i (0 .. len(vec) - SECTION_LEN, inclusive) is weighted by
    getSum(vec, i). The optimizer picks starts that are at least SECTION_LEN
    lines apart so that the total weight is as large as possible. Returns []
    when the poem is shorter than one section, and None when the solver
    produces no model.
    """
    n_starts = len(vec) - SECTION_LEN + 1  # the last full section starts at len(vec) - SECTION_LEN
    if n_starts <= 0:
        return []
    starts = [Bool(str(i)) for i in range(n_starts)]
    # float(): z3 cannot coerce NumPy integer scalars (np.int64 is not a Python int).
    scores = [float(getSum(vec, i)) for i in range(n_starts)]

    o = Optimize()
    for i in range(n_starts):
        following = starts[i + 1:i + SECTION_LEN]
        if following:
            # A chosen section forbids any other start inside its SECTION_LEN lines.
            o.add(Not(And(starts[i], Or(following))))
    o.maximize(Sum([If(starts[i], scores[i], 0) for i in range(n_starts)]))
    try:
        # With the global timeout, check() may return `unknown`; model() then
        # yields the best assignment found so far (or raises if there is none).
        o.check()
        model = o.model()
    except Z3Exception:
        return None
    # model_completion avoids a symbolic (non-bool) result for unassigned vars.
    return [k for k in range(n_starts)
            if is_true(model.eval(starts[k], model_completion=True))]


def _save_results():
    """Pickle both result dicts (closing the files) and report the count."""
    with open("pickles/poem_sec_scores.pcl", "wb") as f:
        pickle.dump(poem_sec_scores, f)
    with open("pickles/poem_sec_vecs.pcl", "wb") as f:
        pickle.dump(poem_sec_vecs, f)
    print(len(poem_sec_vecs))


for ind, (poem, vec) in enumerate(poem_vecs.items()):
    secs = _select_sections(vec)
    if secs is None:
        print("in error")
    else:
        print("in x")
        for sec in secs:
            key = tuple(list(poem)[sec:sec + SECTION_LEN])
            adj = getAdj(vec, sec)  # computed once, used for both outputs
            poem_sec_vecs[key] = adj
            poem_sec_scores[key] = np.sum(adj)

    # Checkpoint every 10 poems, whether or not this poem succeeded.
    if ind % 10 == 9:
        _save_results()

# Flush the final partial batch, which the periodic checkpoint would miss.
_save_results()
